import sys
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
import sys
sys.path.append('/content/drive/MyDrive/final')
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
print('Not connected to a GPU')
else:
print(gpu_info)
Wed Dec 8 16:12:07 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44 Driver Version: 460.32.03 CUDA Version: 11.2 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |
| N/A 40C P0 27W / 250W | 0MiB / 16280MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
pip install tensorflow
Requirement already satisfied: tensorflow in /usr/local/lib/python3.7/dist-packages (2.7.0) Requirement already satisfied: libclang>=9.0.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (12.0.0) Requirement already satisfied: h5py>=2.9.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (3.1.0) Requirement already satisfied: tensorflow-estimator<2.8,~=2.7.0rc0 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (2.7.0) Requirement already satisfied: gast<0.5.0,>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (0.4.0) Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (3.10.0.2) Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (1.6.3) Requirement already satisfied: wheel<1.0,>=0.32.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (0.37.0) Requirement already satisfied: protobuf>=3.9.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (3.17.3) Requirement already satisfied: absl-py>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (0.12.0) Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (2.0) Requirement already satisfied: keras-preprocessing>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (1.1.2) Requirement already satisfied: wrapt>=1.11.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (1.13.3) Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (3.3.0) Requirement already satisfied: keras<2.8,>=2.7.0rc0 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (2.7.0) Requirement already satisfied: tensorboard~=2.6 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (2.7.0) Requirement already satisfied: numpy>=1.14.5 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (1.19.5) Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (0.2.0) Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (1.42.0) Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (1.15.0) Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (1.1.0) Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.21.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow) (0.22.0) Requirement already satisfied: cached-property in /usr/local/lib/python3.7/dist-packages (from h5py>=2.9.0->tensorflow) (1.5.2) Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow) (57.4.0) Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow) (1.35.0) Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow) (0.4.6) Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow) (2.23.0) Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow) (1.0.1) Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow) (0.6.1) Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow) (1.8.0) Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.6->tensorflow) (3.3.6) Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard~=2.6->tensorflow) (0.2.8) Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard~=2.6->tensorflow) (4.8) Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard~=2.6->tensorflow) (4.2.4) Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.6->tensorflow) (1.3.0) Requirement already satisfied: importlib-metadata>=4.4 in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard~=2.6->tensorflow) (4.8.2) Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard~=2.6->tensorflow) (3.6.0) Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard~=2.6->tensorflow) (0.4.8) Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.6->tensorflow) (2.10) Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.6->tensorflow) (1.24.3) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.6->tensorflow) (2021.10.8) Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.6->tensorflow) (3.0.4) Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.6->tensorflow) (3.1.1)
import csv
import json
import os
from PIL import Image
import pprint
import re
import copy
import skimage.draw
import visualize
import utils as my_utils
import matplotlib.pyplot as plt
import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from IPython.display import clear_output
# global context variables
segmentation_path = '/content/drive/MyDrive/final/data/segmentation_masks' # a folder of annotation files
img_path = '/content/drive/MyDrive/final/data/images' # a folder of image files
class img_meta:
# img_name: image file name
# annotations: the annotation list
def __str__(self):
return 'img_name: ' + self.img_name + ', mask_name: \n' + self.mask_name
def __repr__(self):
return 'img_name: ' + self.img_name + ', mask_name: \n' + self.mask_name
def process(self, img_name, segmentation_path):
self.img_name = img_name
img_tokens = img_name.strip('.png')
img_tokens = img_tokens.split('_')
img_key = img_tokens[0]
img_name_key = img_tokens[1]
# print(f"Image name_key: {img_name_key}")
# print(f"Image key: {img_key}")
# Check all segmentation masks
mask_files = os.listdir(segmentation_path)
for file in mask_files:
file_copy = copy.deepcopy(file)
file_copy = re.sub(r"\.[0-9]", "", file_copy)
mask_tokens = file_copy.split('_')
mask_name_key = mask_tokens[0]
for token in mask_tokens:
if len(token)<=2:
mask_key = token
break
# print("Mask Key: ", mask_key)
if img_key == mask_key:
self.mask_name = file
break
#This function uses opencv rather than PIL to load image.
def load_image(self):
img_file_path = os.path.join(img_path, self.img_name)
"""Load the specified image and return a [H,W,3] Numpy array.
"""
# Load image
image = cv2.imread(img_file_path)
# If grayscale. Convert to RGB for consistency.
if image.ndim != 3:
image = skimage.color.gray2rgb(image)
# If has an alpha channel, remove it for consistency
if image.shape[-1] == 4:
image = image[..., :3]
return image
def load_mask(self):
mask_file_path = os.path.join(segmentation_path, self.mask_name)
# Load mask
mask = cv2.imread(mask_file_path, cv2.IMREAD_GRAYSCALE)
# Normalize mask between (0-1)
mask = np.where(mask>0, 1, 0)
return mask
img_metas = []
for file in os.listdir(img_path):
if file.endswith('.png'):
print('Processing', file, '...')
img_meta_obj = img_meta()
img_meta_obj.process(file, segmentation_path)
img_metas.append(img_meta_obj)
else:
print('[WARNING] Non-image file detected:', file)
Processing 3_akulasaisanjay.png ... Processing 1_ahsanmuhammad.png ... Processing 3_anandankit.png ... Processing 4_anandankit.png ... Processing 1_pandirinikhilkumar.png ... Processing 5_peddinenigowtham.png ... Processing 5_patipravasini.png ... Processing 4_patipravasini.png ... Processing 1_adhikarisaugat.png ... Processing 2_ahsanmuhammad.png ... Processing 2_akulasaisanjay.png ... Processing 5_jammulavathsal.png ... Processing 11_bachamollaharshavardhanreddy.png ... Processing 16_bonuvarshinireddy.png ... Processing 7_atlachandrasekharreddy.png ... Processing 11_rajakvarunraju.png ... Processing 16_rayallavineethabhavya.png ... Processing 10_bachamollaharshavardhanreddy.png ... Processing 8_atlachandrasekharreddy.png ... Processing 6_peddinenigowtham.png ... Processing 12_rajakvarunraju.png ... Processing 11_rabbyshahariar.png ... Processing 17_lakkireddybrunda.png ... Processing 15_bonuvarshinireddy.png ... Processing 6_anwarmuhammadjunaid.png ... Processing 14_kongarachanikya.png ... Processing 18_sahapratim.png ... Processing 17_runjalasunithaglory.png ... Processing 15_rayallavineethabhavya.png ... Processing 13_kollasaividwan.png ... Processing 11_karnasaimanishreddy.png ... Processing 10_kanugovimahidhar.png ... Processing 7_anwarmuhammadjunaid.png ... Processing 10_rabbyshahariar.png ... Processing 18_bungmahesh.png ... Processing 12_keshidianudeep.png ... Processing 18_runjalasunithaglory.png ... Processing 6_junnuthulapranayreddy.png ... Processing 19_bungmahesh.png ... Processing 24_chikkalasaikiran.png ... Processing 34_duddukurileelakrishna.png ... Processing 21_seelamsumanthkumarreddy.png ... Processing 21_chaudharyamitlakshmikant.png ... Processing 21_cheniminenihemanthi.png ... Processing 27_dasarkadeep.png ... Processing 22_shakerbilawal.png ... Processing 34_valasanivenugopal.png ... Processing 30_divvelapraveen.png ... Processing 21_shakerbilawal.png ... Processing 34_dykenlandon.png ... Processing 26_dantuluripretham.png ... Processing 27_dantuluripretham.png ... Processing 25_chintamvamsikrishna.png ... Processing 35_dykenlandon.png ... Processing 28_dasarkadeep.png ... Processing 22_cheniminenihemanthi.png ... Processing 20_lewisseth.png ... Processing 33_duddukurileelakrishna.png ... Processing 24_chintamvamsikrishna.png ... Processing 33_valasanivenugopal.png ... Processing 20_seelamsumanthkumarreddy.png ... Processing 20_chaudharyamitlakshmikant.png ... Processing 19_sahapratim.png ... Processing 23_chikkalasaikiran.png ... Processing 48_pandirinikhilkumar.png ... Processing 48_adhikarisaugat.png ... Processing 44_guragainbijay.png ... Processing 44_gundlapallybhanuteja.png ... Processing 36_vodapallikalyani.png ... Processing 40_goturisaikarthikreddy.png ... Processing 52_patipravasini.png ... Processing 39_gorantlasaikrishnavarma.png ... Processing 41_olajideemmanuelolamide.png ... Processing 53_jammulavathsal.png ... Processing 49_ahsanmuhammad.png ... Processing 46_pampanaaditya.png ... Processing 45_guragainbijay.png ... Processing 39_goturisaikarthikreddy.png ... Processing 51_anandankit.png ... Processing 43_pagarepranav.png ... Processing 43_gundlapallybhanuteja.png ... Processing 42_gundebommualekhya.png ... Processing 38_wadoodhamid.png ... Processing 47_indukurisubhavarshitha.png ... Processing 47_panchalsagarlaxmikant.png ... Processing 48_indukurisubhavarshitha.png ... Processing 43_gundebommualekhya.png ... Processing 36_gajamcharantej.png ... Processing 52_jammulavathsal.png ... Processing 38_gorantlasaikrishnavarma.png ... Processing 50_akulasaisanjay.png ... Processing 59_rajakvarunraju.png ... Processing 63_bonuvarshinireddy.png ... Processing 58_karnasaimanishreddy.png ... Processing 60_keshidianudeep.png ... Processing 62_kongarachanikya.png ... Processing 59_keshidianudeep.png ... Processing 65_lakkireddybrunda.png ... Processing 58_rabbyshahariar.png ... Processing 64_lakkireddybrunda.png ... Processing 57_kanugovimahidhar.png ... Processing 58_bachamollaharshavardhanreddy.png ... Processing 66_bungmahesh.png ... Processing 65_runjalasunithaglory.png ... Processing 61_kollasaividwan.png ... Processing 53_junnuthulapranayreddy.png ... Processing 58_kanugovimahidhar.png ... Processing 60_kollasaividwan.png ... Processing 67_lewisseth.png ... Processing 66_sahapratim.png ... Processing 61_kongarachanikya.png ... Processing 55_atlachandrasekharreddy.png ... Processing 59_karnasaimanishreddy.png ... Processing 53_peddinenigowtham.png ... Processing 63_rayallavineethabhavya.png ... Processing 81_duddukurileelakrishna.png ... Processing 69_cheniminenihemanthi.png ... Processing 68_lokahrishikeshreddy.png ... Processing 74_mallempallicharankumar.png ... Processing 87_goturisaikarthikreddy.png ... Processing 75_mallempallicharankumar.png ... Processing 69_lokahrishikeshreddy.png ... Processing 84_gajamcharantej.png ... Processing 76_dayanasri.png ... Processing 68_seelamsumanthkumarreddy.png ... Processing 85_wadoodhamid.png ... Processing 69_shakerbilawal.png ... Processing 86_gorantlasaikrishnavarma.png ... Processing 82_dykenlandon.png ... Processing 81_valasanivenugopal.png ... Processing 84_vodapallikalyani.png ... Processing 72_chintamvamsikrishna.png ... Processing 71_chikkalasaikiran.png ... Processing 68_lewisseth.png ... Processing 68_chaudharyamitlakshmikant.png ... Processing 74_dantuluripretham.png ... Processing 83_nadimpallisriramakesavaakhilvarma.png ... Processing 95_adhikarisaugat.png ... Processing 95_indukurisubhavarshitha.png ... Processing 92_guragainbijay.png ... Processing 91_gundlapallybhanuteja.png ... Processing 89_olajideemmanuelolamide.png ... Processing 94_pampanaaditya.png ... Processing 90_pagarepranav.png ... Processing 90_gundebommualekhya.png ... Processing 88_oladriarchana.png ... Processing 88_olajideemmanuelolamide.png ... Processing 93_pampanaaditya.png ... Processing 91_pagarepranav.png ... Processing 95_pandirinikhilkumar.png ...
print(len(img_metas))
img_meta_obj = img_metas[1]
print(img_meta_obj)
150 img_name: 1_ahsanmuhammad.png, mask_name: ahsanmuhammad_6028238_67421715_1_segment_mask-3.png
img_obj = img_meta_obj.load_image()
print(img_obj.shape)
visualize.display_images([img_obj])
(387, 362, 3)
# Visualize a random file
colors = visualize.random_colors(5)
color = colors[0]
#sample_idx = np.random.choice(len(img_metas))
sample_idx = 1
print("Sample Index: ", sample_idx)
#Image Meta Data
img_meta_obj = img_metas[sample_idx]
img_file = img_meta_obj.load_image()
print("Image shape: ", img_file.shape)
mask = img_meta_obj.load_mask()
print("Mask shape: ", mask.shape)
Sample Index: 1 Image shape: (387, 362, 3) Mask shape: (387, 362)
masked_image = visualize.apply_mask(img_file, mask, color)
visualize.display_images([masked_image])
# Image size that we are going to use
IMG_SIZE = 256
# Our images are RGB (3 channels)
N_CHANNELS = 3
# Foreground and backgroung
N_CLASSES = 2
#Loads the data and creates a tuple of images
#masks to be later made into Tensorflow data object
def dataLoader():
img_list = []
mask_list = []
for i in range(len(img_metas)):
img_meta_obj = img_metas[i]
img_file = img_meta_obj.load_image()
#print("Image shape: ", img_file.shape)
#Convert mask to the shape of an image (3D) with channel=1
mask = img_meta_obj.load_mask()
mask = np.expand_dims(mask, axis =-1)
#Rezize mask and image to common size
# Resize Image to 256x256
img_file, _, scale, padding,_ = my_utils.resize_image(img_file, max_dim = IMG_SIZE, mode="square")
#print("Image shape: ", img_file.shape)
# Resize Mask to 256x256
mask = my_utils.resize_mask(mask, scale, padding)
#Fix Overlap bug
mask = np.where(mask > 0, 1, 0)
mask = mask.astype(np.float32)
#print(np.amax(mask))
#print("Mask shape: ", mask.shape)
#Normalize Image
img_file = img_file.astype(np.float32)
img_file = img_file/255.0
#print("Image shape: ", img_file.shape)
img_list.append(img_file)
mask_list.append(mask)
return (img_list, mask_list)
dataset = tf.data.Dataset.from_tensor_slices(dataLoader())
train_size = int(0.9 * len(img_metas))
val_size = int(0.1 * len(img_metas))
train_dataset = dataset.take(train_size)
val_dataset = dataset.skip(train_size)
BATCH_SIZE = 5
BUFFER_SIZE = 1000
# important for reproducibility
# this allows to generate the same random numbers
SEED = 42
AUTOTUNE = tf.data.experimental.AUTOTUNE
dataset = {"train": train_dataset, "val": val_dataset}
# -- Train Dataset --#
#dataset['train'] = dataset['train'].shuffle(buffer_size=BUFFER_SIZE, seed=SEED)
dataset['train'] = dataset['train'].repeat()
dataset['train'] = dataset['train'].batch(BATCH_SIZE)
dataset['train'] = dataset['train'].prefetch(buffer_size=AUTOTUNE)
#-- Validation Dataset --#
dataset['val'] = dataset['val'].repeat()
dataset['val'] = dataset['val'].batch(BATCH_SIZE)
dataset['val'] = dataset['val'].prefetch(buffer_size=AUTOTUNE)
print(dataset['train'])
print(dataset['val'])
# how shuffle works: https://stackoverflow.com/a/53517848
<PrefetchDataset shapes: ((None, 256, 256, 3), (None, 256, 256, 1)), types: (tf.float32, tf.float32)> <PrefetchDataset shapes: ((None, 256, 256, 3), (None, 256, 256, 1)), types: (tf.float32, tf.float32)>
def display_sample(display_list):
"""Show side-by-side an input image,
the ground truth and the prediction.
"""
plt.figure(figsize=(18, 18))
title = ['Input Image', 'True Mask', 'Predicted Mask']
for i in range(len(display_list)):
plt.subplot(1, len(display_list), i+1)
plt.title(title[i])
plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
plt.axis('off')
plt.show()
for image, mask in dataset['train'].take(4):
sample_image, sample_mask = image, mask
display_sample([sample_image[1], sample_mask[1]])
from keras import backend as K
from keras.layers import concatenate, Conv2DTranspose, Activation
from keras.layers import BatchNormalization
from keras.layers import Conv2D, Input, AvgPool2D
from keras.models import Model
#from keras.layers.convolutional import AtrousConvolution2D
# -- Keras Functional API -- #
# -- UNet++ Implementation -- #
# Everything here is from tensorflow.keras.layers
# I imported tensorflow.keras.layers * to make it easier to read
dropout_rate = 0.5
input_shape = (IMG_SIZE, IMG_SIZE, N_CHANNELS)
n_labels = 2
def conv_batchnorm_relu_block(input_tensor, nb_filter, kernel_size=3):
x = Conv2D(nb_filter, (kernel_size, kernel_size), padding='same')(input_tensor)
x = BatchNormalization(axis=2)(x)
x = Activation('relu')(x)
return x
nb_filter = [32,64,128,256,512]
global bn_axis
K.set_image_data_format("channels_last")
bn_axis = -1
inputs = Input(shape=input_shape, name='input_image')
conv1_1 = conv_batchnorm_relu_block(inputs, nb_filter=nb_filter[0])
pool1 = AvgPool2D((2, 2), strides=(2, 2), name='pool1')(conv1_1)
conv2_1 = conv_batchnorm_relu_block(pool1, nb_filter=nb_filter[1])
pool2 = AvgPool2D((2, 2), strides=(2, 2), name='pool2')(conv2_1)
up1_2 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up12', padding='same')(conv2_1)
conv1_2 = concatenate([up1_2, conv1_1], name='merge12', axis=bn_axis)
conv1_2 = conv_batchnorm_relu_block(conv1_2, nb_filter=nb_filter[0])
conv3_1 = conv_batchnorm_relu_block(pool2, nb_filter=nb_filter[2])
pool3 = AvgPool2D((2, 2), strides=(2, 2), name='pool3')(conv3_1)
up2_2 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up22', padding='same')(conv3_1)
conv2_2 = concatenate([up2_2, conv2_1], name='merge22', axis=bn_axis)
conv2_2 = conv_batchnorm_relu_block(conv2_2, nb_filter=nb_filter[1])
up1_3 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up13', padding='same')(conv2_2)
conv1_3 = concatenate([up1_3, conv1_1, conv1_2], name='merge13', axis=bn_axis)
conv1_3 = conv_batchnorm_relu_block(conv1_3, nb_filter=nb_filter[0])
conv4_1 = conv_batchnorm_relu_block(pool3, nb_filter=nb_filter[3])
pool4 = AvgPool2D((2, 2), strides=(2, 2), name='pool4')(conv4_1)
up3_2 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up32', padding='same')(conv4_1)
conv3_2 = concatenate([up3_2, conv3_1], name='merge32', axis=bn_axis)
conv3_2 = conv_batchnorm_relu_block(conv3_2, nb_filter=nb_filter[2])
up2_3 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up23', padding='same')(conv3_2)
conv2_3 = concatenate([up2_3, conv2_1, conv2_2], name='merge23', axis=bn_axis)
conv2_3 = conv_batchnorm_relu_block(conv2_3, nb_filter=nb_filter[1])
up1_4 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up14', padding='same')(conv2_3)
conv1_4 = concatenate([up1_4, conv1_1, conv1_2, conv1_3], name='merge14', axis=bn_axis)
conv1_4 = conv_batchnorm_relu_block(conv1_4, nb_filter=nb_filter[0])
conv5_1 = conv_batchnorm_relu_block(pool4, nb_filter=nb_filter[4])
up4_2 = Conv2DTranspose(nb_filter[3], (2, 2), strides=(2, 2), name='up42', padding='same')(conv5_1)
conv4_2 = concatenate([up4_2, conv4_1], name='merge42', axis=bn_axis)
conv4_2 = conv_batchnorm_relu_block(conv4_2, nb_filter=nb_filter[3])
up3_3 = Conv2DTranspose(nb_filter[2], (2, 2), strides=(2, 2), name='up33', padding='same')(conv4_2)
conv3_3 = concatenate([up3_3, conv3_1, conv3_2], name='merge33', axis=bn_axis)
conv3_3 = conv_batchnorm_relu_block(conv3_3, nb_filter=nb_filter[2])
up2_4 = Conv2DTranspose(nb_filter[1], (2, 2), strides=(2, 2), name='up24', padding='same')(conv3_3)
conv2_4 = concatenate([up2_4, conv2_1, conv2_2, conv2_3], name='merge24', axis=bn_axis)
conv2_4 = conv_batchnorm_relu_block(conv2_4, nb_filter=nb_filter[1])
up1_5 = Conv2DTranspose(nb_filter[0], (2, 2), strides=(2, 2), name='up15', padding='same')(conv2_4)
conv1_5 = concatenate([up1_5, conv1_1, conv1_2, conv1_3, conv1_4], name='merge15', axis=bn_axis)
conv1_5 = conv_batchnorm_relu_block(conv1_5, nb_filter=nb_filter[0])
nestnet_output_1 = Conv2D(n_labels, (1, 1), activation='sigmoid', name='output_1',padding='same')(conv1_2)
nestnet_output_2 = Conv2D(n_labels, (1, 1), activation='sigmoid', name='output_2', padding='same' )(conv1_3)
nestnet_output_3 = Conv2D(n_labels, (1, 1), activation='sigmoid', name='output_3', padding='same')(conv1_4)
nestnet_output_4 = Conv2D(n_labels, (1, 1), activation='sigmoid', name='output_4', padding='same')(conv1_5)
#This class will automatically save TF models
#at Epoch which are multiples of the SAVE_MULTIPLE parameter.
SAVE_MULTIPLE = 5
class ModelSaver(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs={}):
if (epoch)%SAVE_MULTIPLE == 0: # Save when epochs are multiples of SAVE_MULTIPLE.
self.model.save(f"/content/drive/MyDrive/final/saved_models_Unet++/model_{epoch}.h5")
https://lars76.github.io/2018/09/27/loss-functions-for-segmentation.html
def dice_loss(y_true, y_pred):
y_true = tf.cast(y_true, tf.float32)
y_pred = tf.math.sigmoid(y_pred)
numerator = 2 * tf.reduce_sum(y_true * y_pred)
denominator = tf.reduce_sum(y_true + y_pred)
return 1 - numerator / denominator
model = Model(inputs=inputs, outputs=nestnet_output_4)
# model.compile(optimizer = Adam(learning_rate=0.0001),
# loss = dice_loss,
# metrics=['accuracy'])
model.compile(optimizer=Adam(learning_rate=0.0001), loss = tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy'])
# create and use callback:
saver = ModelSaver()
EPOCHS = 100
STEPS_PER_EPOCH = train_size // BATCH_SIZE
VALIDATION_STEPS = val_size // BATCH_SIZE
# sometimes it can be very interesting to run some batches on cpu
# because the tracing is way better than on GPU
# you will have more obvious error message
# but in our case, it takes A LOT of time
# #On CPU
# with tf.device("/cpu:0"):
# model_history = model.fit(dataset['train'], epochs=EPOCHS,
# callbacks=[saver],
# steps_per_epoch=STEPS_PER_EPOCH,
# validation_steps=VALIDATION_STEPS,
# validation_data=dataset['val'])
# #On GPU
model_history = model.fit(dataset['train'], epochs=EPOCHS,
callbacks=[saver],
steps_per_epoch=STEPS_PER_EPOCH,
validation_steps=VALIDATION_STEPS,
validation_data=dataset['val'])
Epoch 1/100 6/27 [=====>........................] - ETA: 2s - loss: 0.2997 - accuracy: 0.9120WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0519s vs `on_train_batch_end` time: 0.0633s). Check your callbacks. 27/27 [==============================] - 18s 180ms/step - loss: 0.1473 - accuracy: 0.9585 - val_loss: 0.3735 - val_accuracy: 0.9804 Epoch 2/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0873 - accuracy: 0.9727 - val_loss: 0.3147 - val_accuracy: 0.9804 Epoch 3/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0722 - accuracy: 0.9736 - val_loss: 0.3152 - val_accuracy: 0.9805 Epoch 4/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0549 - accuracy: 0.9819 - val_loss: 0.2408 - val_accuracy: 0.9804 Epoch 5/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0434 - accuracy: 0.9871 - val_loss: 0.1964 - val_accuracy: 0.9811 Epoch 6/100 27/27 [==============================] - 4s 151ms/step - loss: 0.0378 - accuracy: 0.9887 - val_loss: 0.1587 - val_accuracy: 0.9817 Epoch 7/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0345 - accuracy: 0.9895 - val_loss: 0.1132 - val_accuracy: 0.9823 Epoch 8/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0323 - accuracy: 0.9899 - val_loss: 0.0837 - val_accuracy: 0.9826 Epoch 9/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0310 - accuracy: 0.9902 - val_loss: 0.0647 - val_accuracy: 0.9828 Epoch 10/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0299 - accuracy: 0.9904 - val_loss: 0.0552 - val_accuracy: 0.9842 Epoch 11/100 27/27 [==============================] - 4s 151ms/step - loss: 0.0286 - accuracy: 0.9907 - val_loss: 0.0483 - val_accuracy: 0.9840 Epoch 12/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0277 - accuracy: 0.9909 - val_loss: 0.0443 - val_accuracy: 0.9844 Epoch 13/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0270 - accuracy: 0.9911 - val_loss: 0.0414 - val_accuracy: 0.9850 Epoch 14/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0264 - accuracy: 0.9912 - val_loss: 0.0404 - val_accuracy: 0.9849 Epoch 15/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0259 - accuracy: 0.9913 - val_loss: 0.0407 - val_accuracy: 0.9848 Epoch 16/100 27/27 [==============================] - 4s 151ms/step - loss: 0.0256 - accuracy: 0.9913 - val_loss: 0.0389 - val_accuracy: 0.9853 Epoch 17/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0252 - accuracy: 0.9914 - val_loss: 0.0382 - val_accuracy: 0.9855 Epoch 18/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0249 - accuracy: 0.9914 - val_loss: 0.0367 - val_accuracy: 0.9860 Epoch 19/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0246 - accuracy: 0.9915 - val_loss: 0.0362 - val_accuracy: 0.9862 Epoch 20/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0243 - accuracy: 0.9915 - val_loss: 0.0348 - val_accuracy: 0.9865 Epoch 21/100 27/27 [==============================] - 4s 152ms/step - loss: 0.0242 - accuracy: 0.9915 - val_loss: 0.0336 - val_accuracy: 0.9868 Epoch 22/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0238 - accuracy: 0.9916 - val_loss: 0.0338 - val_accuracy: 0.9868 Epoch 23/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0235 - accuracy: 0.9917 - val_loss: 0.0344 - val_accuracy: 0.9867 Epoch 24/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0232 - accuracy: 0.9918 - val_loss: 0.0362 - val_accuracy: 0.9864 Epoch 25/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0232 - accuracy: 0.9917 - val_loss: 0.0318 - val_accuracy: 0.9875 Epoch 26/100 27/27 [==============================] - 4s 151ms/step - loss: 0.0229 - accuracy: 0.9918 - val_loss: 0.0316 - val_accuracy: 0.9879 Epoch 27/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0226 - accuracy: 0.9919 - val_loss: 0.0319 - val_accuracy: 0.9879 Epoch 28/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0224 - accuracy: 0.9919 - val_loss: 0.0293 - val_accuracy: 0.9883 Epoch 29/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0221 - accuracy: 0.9920 - val_loss: 0.0295 - val_accuracy: 0.9883 Epoch 30/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0219 - accuracy: 0.9921 - val_loss: 0.0285 - val_accuracy: 0.9886 Epoch 31/100 27/27 [==============================] - 4s 152ms/step - loss: 0.0216 - accuracy: 0.9922 - val_loss: 0.0302 - val_accuracy: 0.9884 Epoch 32/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0215 - accuracy: 0.9922 - val_loss: 0.0279 - val_accuracy: 0.9885 Epoch 33/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0212 - accuracy: 0.9923 - val_loss: 0.0281 - val_accuracy: 0.9885 Epoch 34/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0210 - accuracy: 0.9923 - val_loss: 0.0262 - val_accuracy: 0.9892 Epoch 35/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0206 - accuracy: 0.9924 - val_loss: 0.0260 - val_accuracy: 0.9894 Epoch 36/100 27/27 [==============================] - 4s 152ms/step - loss: 0.0204 - accuracy: 0.9924 - val_loss: 0.0257 - val_accuracy: 0.9896 Epoch 37/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0203 - accuracy: 0.9925 - val_loss: 0.0263 - val_accuracy: 0.9895 Epoch 38/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0206 - accuracy: 0.9924 - val_loss: 0.0277 - val_accuracy: 0.9889 Epoch 39/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0200 - accuracy: 0.9926 - val_loss: 0.0294 - val_accuracy: 0.9884 Epoch 40/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0198 - accuracy: 0.9926 - val_loss: 0.0317 - val_accuracy: 0.9881 Epoch 41/100 27/27 [==============================] - 4s 152ms/step - loss: 0.0204 - accuracy: 0.9924 - val_loss: 0.0278 - val_accuracy: 0.9888 Epoch 42/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0198 - accuracy: 0.9926 - val_loss: 0.0289 - val_accuracy: 0.9887 Epoch 43/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0195 - accuracy: 0.9927 - val_loss: 0.0292 - val_accuracy: 0.9886 Epoch 44/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0195 - accuracy: 0.9927 - val_loss: 0.0290 - val_accuracy: 0.9886 Epoch 45/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0195 - accuracy: 0.9927 - val_loss: 0.0336 - val_accuracy: 0.9879 Epoch 46/100 27/27 [==============================] - 4s 152ms/step - loss: 0.0193 - accuracy: 0.9928 - val_loss: 0.0276 - val_accuracy: 0.9892 Epoch 47/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0193 - accuracy: 0.9927 - val_loss: 0.0323 - val_accuracy: 0.9883 Epoch 48/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0191 - accuracy: 0.9928 - val_loss: 0.0249 - val_accuracy: 0.9901 Epoch 49/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0190 - accuracy: 0.9928 - val_loss: 0.0258 - val_accuracy: 0.9899 Epoch 50/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0186 - accuracy: 0.9929 - val_loss: 0.0252 - val_accuracy: 0.9902 Epoch 51/100 27/27 [==============================] - 4s 152ms/step - loss: 0.0183 - accuracy: 0.9931 - val_loss: 0.0257 - val_accuracy: 0.9901 Epoch 52/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0180 - accuracy: 0.9932 - val_loss: 0.0264 - val_accuracy: 0.9900 Epoch 53/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0177 - accuracy: 0.9933 - val_loss: 0.0256 - val_accuracy: 0.9901 Epoch 54/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0176 - accuracy: 0.9933 - val_loss: 0.0285 - val_accuracy: 0.9891 Epoch 55/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0174 - accuracy: 0.9934 - val_loss: 0.0271 - val_accuracy: 0.9896 Epoch 56/100 27/27 [==============================] - 4s 151ms/step - loss: 0.0173 - accuracy: 0.9934 - val_loss: 0.0278 - val_accuracy: 0.9895 Epoch 57/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0171 - accuracy: 0.9935 - val_loss: 0.0273 - val_accuracy: 0.9896 Epoch 58/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0169 - accuracy: 0.9936 - val_loss: 0.0288 - val_accuracy: 0.9893 Epoch 59/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0166 - accuracy: 0.9937 - val_loss: 0.0263 - val_accuracy: 0.9900 Epoch 60/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0165 - accuracy: 0.9937 - val_loss: 0.0245 - val_accuracy: 0.9905 Epoch 61/100 27/27 [==============================] - 4s 152ms/step - loss: 0.0160 - accuracy: 0.9940 - val_loss: 0.0277 - val_accuracy: 0.9895 Epoch 62/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0159 - accuracy: 0.9940 - val_loss: 0.0289 - val_accuracy: 0.9893 Epoch 63/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0155 - accuracy: 0.9942 - val_loss: 0.0256 - val_accuracy: 0.9903 Epoch 64/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0151 - accuracy: 0.9943 - val_loss: 0.0269 - val_accuracy: 0.9902 Epoch 65/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0150 - accuracy: 0.9944 - val_loss: 0.0290 - val_accuracy: 0.9896 Epoch 66/100 27/27 [==============================] - 4s 152ms/step - loss: 0.0147 - accuracy: 0.9945 - val_loss: 0.0315 - val_accuracy: 0.9893 Epoch 67/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0145 - accuracy: 0.9946 - val_loss: 0.0274 - val_accuracy: 0.9903 Epoch 68/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0143 - accuracy: 0.9946 - val_loss: 0.0292 - val_accuracy: 0.9897 Epoch 69/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0140 - accuracy: 0.9948 - val_loss: 0.0333 - val_accuracy: 0.9887 Epoch 70/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0136 - accuracy: 0.9950 - val_loss: 0.0293 - val_accuracy: 0.9898 Epoch 71/100 27/27 [==============================] - 4s 153ms/step - loss: 0.0132 - accuracy: 0.9951 - val_loss: 0.0288 - val_accuracy: 0.9900 Epoch 72/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0128 - accuracy: 0.9953 - val_loss: 0.0276 - val_accuracy: 0.9903 Epoch 73/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0123 - accuracy: 0.9955 - val_loss: 0.0277 - val_accuracy: 0.9903 Epoch 74/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0117 - accuracy: 0.9958 - val_loss: 0.0293 - val_accuracy: 0.9901 Epoch 75/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0113 - accuracy: 0.9960 - val_loss: 0.0285 - val_accuracy: 0.9903 Epoch 76/100 27/27 [==============================] - 4s 151ms/step - loss: 0.0110 - accuracy: 0.9961 - val_loss: 0.0289 - val_accuracy: 0.9903 Epoch 77/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0108 - accuracy: 0.9962 - val_loss: 0.0322 - val_accuracy: 0.9897 Epoch 78/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0109 - accuracy: 0.9961 - val_loss: 0.0311 - val_accuracy: 0.9901 Epoch 79/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0108 - accuracy: 0.9961 - val_loss: 0.0307 - val_accuracy: 0.9902 Epoch 80/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0107 - accuracy: 0.9962 - val_loss: 0.0331 - val_accuracy: 0.9898 Epoch 81/100 27/27 [==============================] - 4s 152ms/step - loss: 0.0110 - accuracy: 0.9961 - val_loss: 0.0361 - val_accuracy: 0.9894 Epoch 82/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0113 - accuracy: 0.9959 - val_loss: 0.0360 - val_accuracy: 0.9894 Epoch 83/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0108 - accuracy: 0.9961 - val_loss: 0.0323 - val_accuracy: 0.9900 Epoch 84/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0101 - accuracy: 0.9965 - val_loss: 0.0316 - val_accuracy: 0.9901 Epoch 85/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0095 - accuracy: 0.9968 - val_loss: 0.0287 - val_accuracy: 0.9904 Epoch 86/100 27/27 [==============================] - 4s 151ms/step - loss: 0.0087 - accuracy: 0.9971 - val_loss: 0.0280 - val_accuracy: 0.9906 Epoch 87/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0084 - accuracy: 0.9972 - val_loss: 0.0289 - val_accuracy: 0.9903 Epoch 88/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0081 - accuracy: 0.9974 - val_loss: 0.0284 - val_accuracy: 0.9907 Epoch 89/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0078 - accuracy: 0.9975 - val_loss: 0.0302 - val_accuracy: 0.9905 Epoch 90/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0076 - accuracy: 0.9976 - val_loss: 0.0314 - val_accuracy: 0.9903 Epoch 91/100 27/27 [==============================] - 4s 151ms/step - loss: 0.0074 - accuracy: 0.9977 - val_loss: 0.0330 - val_accuracy: 0.9900 Epoch 92/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0073 - accuracy: 0.9977 - val_loss: 0.0325 - val_accuracy: 0.9900 Epoch 93/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0072 - accuracy: 0.9978 - val_loss: 0.0300 - val_accuracy: 0.9904 Epoch 94/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0071 - accuracy: 0.9978 - val_loss: 0.0319 - val_accuracy: 0.9902 Epoch 95/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0070 - accuracy: 0.9978 - val_loss: 0.0336 - val_accuracy: 0.9902 Epoch 96/100 27/27 [==============================] - 4s 152ms/step - loss: 0.0066 - accuracy: 0.9980 - val_loss: 0.0320 - val_accuracy: 0.9904 Epoch 97/100 27/27 [==============================] - 4s 136ms/step - loss: 0.0066 - accuracy: 0.9980 - val_loss: 0.0331 - val_accuracy: 0.9905 Epoch 98/100 27/27 [==============================] - 4s 139ms/step - loss: 0.0064 - accuracy: 0.9980 - val_loss: 0.0327 - val_accuracy: 0.9906 Epoch 99/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0064 - accuracy: 0.9980 - val_loss: 0.0344 - val_accuracy: 0.9902 Epoch 100/100 27/27 [==============================] - 4s 137ms/step - loss: 0.0061 - accuracy: 0.9981 - val_loss: 0.0324 - val_accuracy: 0.9906
## Save the model at the latest EPOCH or as desired
model.save(f"/content/drive/MyDrive/final/saved_models_Unet++/model_{EPOCHS}.h5")
MODEL_PATH = "/content/drive/MyDrive/final/saved_models_Unet++/model_95.h5" #Change this to the model path you want to load
model.load_weights(MODEL_PATH, by_name=True)
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_image (InputLayer) [(None, 256, 256, 3 0 []
)]
conv2d (Conv2D) (None, 256, 256, 32 896 ['input_image[0][0]']
)
batch_normalization (BatchNorm (None, 256, 256, 32 1024 ['conv2d[0][0]']
alization) )
activation (Activation) (None, 256, 256, 32 0 ['batch_normalization[0][0]']
)
pool1 (AveragePooling2D) (None, 128, 128, 32 0 ['activation[0][0]']
)
conv2d_1 (Conv2D) (None, 128, 128, 64 18496 ['pool1[0][0]']
)
batch_normalization_1 (BatchNo (None, 128, 128, 64 512 ['conv2d_1[0][0]']
rmalization) )
activation_1 (Activation) (None, 128, 128, 64 0 ['batch_normalization_1[0][0]']
)
pool2 (AveragePooling2D) (None, 64, 64, 64) 0 ['activation_1[0][0]']
conv2d_3 (Conv2D) (None, 64, 64, 128) 73856 ['pool2[0][0]']
batch_normalization_3 (BatchNo (None, 64, 64, 128) 256 ['conv2d_3[0][0]']
rmalization)
activation_3 (Activation) (None, 64, 64, 128) 0 ['batch_normalization_3[0][0]']
pool3 (AveragePooling2D) (None, 32, 32, 128) 0 ['activation_3[0][0]']
conv2d_6 (Conv2D) (None, 32, 32, 256) 295168 ['pool3[0][0]']
batch_normalization_6 (BatchNo (None, 32, 32, 256) 128 ['conv2d_6[0][0]']
rmalization)
activation_6 (Activation) (None, 32, 32, 256) 0 ['batch_normalization_6[0][0]']
pool4 (AveragePooling2D) (None, 16, 16, 256) 0 ['activation_6[0][0]']
conv2d_10 (Conv2D) (None, 16, 16, 512) 1180160 ['pool4[0][0]']
batch_normalization_10 (BatchN (None, 16, 16, 512) 64 ['conv2d_10[0][0]']
ormalization)
activation_10 (Activation) (None, 16, 16, 512) 0 ['batch_normalization_10[0][0]']
up42 (Conv2DTranspose) (None, 32, 32, 256) 524544 ['activation_10[0][0]']
merge42 (Concatenate) (None, 32, 32, 512) 0 ['up42[0][0]',
'activation_6[0][0]']
up32 (Conv2DTranspose) (None, 64, 64, 128) 131200 ['activation_6[0][0]']
conv2d_11 (Conv2D) (None, 32, 32, 256) 1179904 ['merge42[0][0]']
merge32 (Concatenate) (None, 64, 64, 256) 0 ['up32[0][0]',
'activation_3[0][0]']
up22 (Conv2DTranspose) (None, 128, 128, 64 32832 ['activation_3[0][0]']
)
batch_normalization_11 (BatchN (None, 32, 32, 256) 128 ['conv2d_11[0][0]']
ormalization)
conv2d_7 (Conv2D) (None, 64, 64, 128) 295040 ['merge32[0][0]']
merge22 (Concatenate) (None, 128, 128, 12 0 ['up22[0][0]',
8) 'activation_1[0][0]']
up12 (Conv2DTranspose) (None, 256, 256, 32 8224 ['activation_1[0][0]']
)
activation_11 (Activation) (None, 32, 32, 256) 0 ['batch_normalization_11[0][0]']
batch_normalization_7 (BatchNo (None, 64, 64, 128) 256 ['conv2d_7[0][0]']
rmalization)
conv2d_4 (Conv2D) (None, 128, 128, 64 73792 ['merge22[0][0]']
)
merge12 (Concatenate) (None, 256, 256, 64 0 ['up12[0][0]',
) 'activation[0][0]']
up33 (Conv2DTranspose) (None, 64, 64, 128) 131200 ['activation_11[0][0]']
activation_7 (Activation) (None, 64, 64, 128) 0 ['batch_normalization_7[0][0]']
batch_normalization_4 (BatchNo (None, 128, 128, 64 512 ['conv2d_4[0][0]']
rmalization) )
conv2d_2 (Conv2D) (None, 256, 256, 32 18464 ['merge12[0][0]']
)
merge33 (Concatenate) (None, 64, 64, 384) 0 ['up33[0][0]',
'activation_3[0][0]',
'activation_7[0][0]']
activation_4 (Activation) (None, 128, 128, 64 0 ['batch_normalization_4[0][0]']
)
up23 (Conv2DTranspose) (None, 128, 128, 64 32832 ['activation_7[0][0]']
)
batch_normalization_2 (BatchNo (None, 256, 256, 32 1024 ['conv2d_2[0][0]']
rmalization) )
conv2d_12 (Conv2D) (None, 64, 64, 128) 442496 ['merge33[0][0]']
merge23 (Concatenate) (None, 128, 128, 19 0 ['up23[0][0]',
2) 'activation_1[0][0]',
'activation_4[0][0]']
activation_2 (Activation) (None, 256, 256, 32 0 ['batch_normalization_2[0][0]']
)
up13 (Conv2DTranspose) (None, 256, 256, 32 8224 ['activation_4[0][0]']
)
batch_normalization_12 (BatchN (None, 64, 64, 128) 256 ['conv2d_12[0][0]']
ormalization)
conv2d_8 (Conv2D) (None, 128, 128, 64 110656 ['merge23[0][0]']
)
merge13 (Concatenate) (None, 256, 256, 96 0 ['up13[0][0]',
) 'activation[0][0]',
'activation_2[0][0]']
activation_12 (Activation) (None, 64, 64, 128) 0 ['batch_normalization_12[0][0]']
batch_normalization_8 (BatchNo (None, 128, 128, 64 512 ['conv2d_8[0][0]']
rmalization) )
conv2d_5 (Conv2D) (None, 256, 256, 32 27680 ['merge13[0][0]']
)
up24 (Conv2DTranspose) (None, 128, 128, 64 32832 ['activation_12[0][0]']
)
activation_8 (Activation) (None, 128, 128, 64 0 ['batch_normalization_8[0][0]']
)
batch_normalization_5 (BatchNo (None, 256, 256, 32 1024 ['conv2d_5[0][0]']
rmalization) )
merge24 (Concatenate) (None, 128, 128, 25 0 ['up24[0][0]',
6) 'activation_1[0][0]',
'activation_4[0][0]',
'activation_8[0][0]']
activation_5 (Activation) (None, 256, 256, 32 0 ['batch_normalization_5[0][0]']
)
up14 (Conv2DTranspose) (None, 256, 256, 32 8224 ['activation_8[0][0]']
)
conv2d_13 (Conv2D) (None, 128, 128, 64 147520 ['merge24[0][0]']
)
merge14 (Concatenate) (None, 256, 256, 12 0 ['up14[0][0]',
8) 'activation[0][0]',
'activation_2[0][0]',
'activation_5[0][0]']
batch_normalization_13 (BatchN (None, 128, 128, 64 512 ['conv2d_13[0][0]']
ormalization) )
conv2d_9 (Conv2D) (None, 256, 256, 32 36896 ['merge14[0][0]']
)
activation_13 (Activation) (None, 128, 128, 64 0 ['batch_normalization_13[0][0]']
)
batch_normalization_9 (BatchNo (None, 256, 256, 32 1024 ['conv2d_9[0][0]']
rmalization) )
up15 (Conv2DTranspose) (None, 256, 256, 32 8224 ['activation_13[0][0]']
)
activation_9 (Activation) (None, 256, 256, 32 0 ['batch_normalization_9[0][0]']
)
merge15 (Concatenate) (None, 256, 256, 16 0 ['up15[0][0]',
0) 'activation[0][0]',
'activation_2[0][0]',
'activation_5[0][0]',
'activation_9[0][0]']
conv2d_14 (Conv2D) (None, 256, 256, 32 46112 ['merge15[0][0]']
)
batch_normalization_14 (BatchN (None, 256, 256, 32 1024 ['conv2d_14[0][0]']
ormalization) )
activation_14 (Activation) (None, 256, 256, 32 0 ['batch_normalization_14[0][0]']
)
output_4 (Conv2D) (None, 256, 256, 2) 66 ['activation_14[0][0]']
==================================================================================================
Total params: 4,873,794
Trainable params: 4,869,666
Non-trainable params: 4,128
__________________________________________________________________________________________________
import sys
np.set_printoptions(threshold=sys.maxsize)
#Pick first batch of Image from Validation set
for image, mask in dataset['val'].take(1):
sample_image, sample_mask = image, mask
#Pick first Prediction Validation set
sample_idx = 2
pred_mask = model.predict(sample_image)
sample_image = sample_image.numpy()[sample_idx] #Pick first image form a batch of 5
sample_image = (sample_image*255.0).astype(np.uint32)
#Ground Truth Mask
sample_mask = sample_mask.numpy()[sample_idx]
sample_mask = np.squeeze(sample_mask, axis =-1)
#Predicted Mask
pred_mask = pred_mask[sample_idx] #Pick first mask form a batch of 5
print(sample_image.shape)
print(pred_mask.shape)
(256, 256, 3) (256, 256, 2)
pred_mask = np.argmax(pred_mask, axis =-1)
#pred_mask = np.expand_dims(pred_mask, axis =-1)
print(sample_image.shape)
print(pred_mask.shape)
(256, 256, 3) (256, 256)
colors = visualize.random_colors(10)
color = colors[0]
print(color)
masked_image = sample_image.astype(np.uint32).copy()
masked_image = visualize.apply_mask(masked_image, pred_mask, color)
visualize.display_images([masked_image])
(0.7999999999999998, 1.0, 0.0)
# Vizualize Ground Truth
#GT
colors = visualize.random_colors(10)
color = colors[0]
print(color)
masked_image = sample_image.astype(np.uint32).copy()
print(masked_image.shape)
print(sample_mask.shape)
masked_image = visualize.apply_mask(masked_image, sample_mask, color)
visualize.display_images([masked_image])
(0.1999999999999993, 0.0, 1.0) (256, 256, 3) (256, 256)
import pprint
import numpy as np
np.set_printoptions(threshold=sys.maxsize)
count = 0
# for image, mask in dataset['val'].take(1):
# sample_image, sample_mask = image, mask
for image, mask in dataset['val']:
if count > 4:
break
sample_image, sample_mask = image, mask
#Pick first Prediction Validation set
sample_idx = 3
pred_mask = model.predict(sample_image)
sample_image = sample_image.numpy()[sample_idx] #Pick first image form a batch of 5
sample_image = (sample_image*255.0).astype(np.uint32)
#Ground Truth Mask
sample_mask = sample_mask.numpy()[sample_idx]
sample_mask = np.squeeze(sample_mask, axis =-1)
#Predicted Mask
pred_mask = pred_mask[sample_idx] #Pick first mask form a batch of 5
pred_mask = np.argmax(pred_mask, axis =-1)
colors = visualize.random_colors(10)
color = colors[0]
color2 = colors[1]
Input_image = sample_image.astype(np.uint32).copy()
True_Mask = sample_image.astype(np.uint32).copy()
Pred_Mask = sample_image.astype(np.uint32).copy()
True_Mask = visualize.apply_mask(True_Mask, sample_mask, color2)
Pred_Mask = visualize.apply_mask(Pred_Mask, pred_mask, color)
#visualize.display_images([masked_image])
display_sample([Input_image,True_Mask,Pred_Mask])
count += 1
count = 0
avg_pixel = []
avg_IOU = []
for image, mask in dataset['val']:
if count > 4:
break
sample_image, sample_mask = image, mask
#Pick first Prediction Validation set
sample_idx = 3
pred_mask = model.predict(sample_image)
sample_image = sample_image.numpy()[sample_idx] #Pick first image form a batch of 5
sample_image = (sample_image*255.0).astype(np.uint32)
#Ground Truth Mask
sample_mask = sample_mask.numpy()[sample_idx]
sample_mask = np.squeeze(sample_mask, axis =-1)
print
#Predicted Mask
pred_mask = pred_mask[sample_idx] #Pick first mask form a batch of 5
pred_mask = np.argmax(pred_mask, axis =-1)
true = copy.deepcopy(sample_mask)
pred = copy.deepcopy(pred_mask)
intersection = np.logical_and(sample_mask, pred_mask)
union = np.logical_or(sample_mask, pred_mask)
iou_score = np.sum(intersection) / np.sum(union)
avg_IOU.append(iou_score)
print("IOU: " + str(iou_score))
# True Positive (TP): we predict a label of 1 (positive), and the true label is 1.
TP = np.sum(np.logical_and(pred == 1, true == 1))
# True Negative (TN): we predict a label of 0 (negative), and the true label is 0.
TN = np.sum(np.logical_and(pred == 0, true == 0))
# False Positive (FP): we predict a label of 1 (positive), but the true label is 0.
FP = np.sum(np.logical_and(pred == 1, true == 0))
# False Negative (FN): we predict a label of 0 (negative), but the true label is 1.
FN = np.sum(np.logical_and(pred == 0, true == 1))
pixel_accuracy = (TP+TN)/(TP+FP+TN+FN)
print("pixel_accuracy: " +str(pixel_accuracy))
avg_pixel.append(pixel_accuracy)
count += 1
IOU: 0.5501432664756447 pixel_accuracy: 0.99041748046875 IOU: 0.49888641425389757 pixel_accuracy: 0.9965667724609375 IOU: 0.6440281030444965 pixel_accuracy: 0.9976806640625 IOU: 0.5501432664756447 pixel_accuracy: 0.99041748046875 IOU: 0.49888641425389757 pixel_accuracy: 0.9965667724609375
print("Average Pixel Accuracy: " + str(sum(avg_pixel)/len(avg_pixel)))
print("Average IOU: " + str(sum(avg_IOU)/len(avg_IOU)))
Average Pixel Accuracy: 0.994329833984375 Average IOU: 0.5484174929007162